package com.mimo.admin.shiro.filter;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;

import org.apache.shiro.web.filter.AccessControlFilter;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.mimo.admin.constants.HttpParams;

/**
 * Shiro filter that guards resources by validating the HTTP {@code Referer} header
 * (anti-hotlinking / basic CSRF defence).
 *
 * <p>A request is allowed when any of the following holds:
 * <ul>
 *   <li>checking is disabled;</li>
 *   <li>its path is whitelisted: the login URL, {@code "/"} or {@code "/actuator/info"};</li>
 *   <li>its {@code Referer} starts with one of the configured legal domains, followed by a URL
 *       boundary.</li>
 * </ul>
 * Any other request is logged and then handled by {@link #onAccessDenied}, which saves it and
 * redirects to the login URL.
 */
public class CustomAccessControlFilter extends AccessControlFilter {

  private static final Logger logger = LoggerFactory.getLogger(CustomAccessControlFilter.class);

  private static final String ACTUATOR_URI = "/actuator/info";
  private static final String ROOT_URI = "/";

  /**
   * Legal Referer prefixes, held as an immutable snapshot. They are compared against the start of
   * the full Referer URL, so entries are expected to include the scheme (e.g.
   * {@code https://host}).
   */
  private final List<String> domains;

  /** Whether Referer checking is enabled; can be switched off in test environments. */
  private final boolean csrfEnable;

  /**
   * Creates the filter.
   *
   * @param domains legal Referer prefixes. {@code null} is treated as an empty list, so every
   *     request that needs checking is rejected. The list is copied, so later changes by the caller
   *     have no effect.
   * @param csrfEnable {@code false} disables all Referer checking (e.g. for test environments)
   */
  public CustomAccessControlFilter(List<String> domains, boolean csrfEnable) {
    super();
    this.domains = domains == null
        ? Collections.<String>emptyList()
        : Collections.unmodifiableList(new ArrayList<>(domains));
    this.csrfEnable = csrfEnable;
  }

  /**
   * Handles a request that failed {@link #isAccessAllowed}.
   *
   * <p>The check is repeated so the method stays correct if it is ever invoked directly. In Shiro's
   * {@code onPreHandle} it is only reached after {@code isAccessAllowed} returned {@code false}.
   *
   * @return {@code true} to continue the filter chain, or {@code false} after the request was
   *     saved and redirected to the login URL
   */
  @Override
  protected boolean onAccessDenied(ServletRequest request, ServletResponse response) throws Exception {
    if (isLegalRequest(request, response)) {
      return true;
    }
    // Direct access (no Referer) or a foreign origin: save the request and send the user to login.
    saveRequestAndRedirectToLogin(request, response);
    return false;
  }

  /**
   * Returns whether the request may proceed without redirection. Rejected requests are logged at
   * WARN level together with their Referer and request URI.
   */
  @Override
  protected boolean isAccessAllowed(ServletRequest request, ServletResponse response, Object mappedValue) {
    if (isLegalRequest(request, response)) {
      return true;
    }
    HttpServletRequest httpRequest = (HttpServletRequest) request;
    logger.warn("拦截到非法请求，Referer={},RequestURI={},",
        httpRequest.getHeader(HttpParams.REFERER), httpRequest.getRequestURI());
    return false;
  }

  /**
   * Single source of truth for the Referer check shared by {@link #isAccessAllowed} and
   * {@link #onAccessDenied}.
   *
   * @return {@code true} when checking is disabled, the URL is whitelisted, or the Referer matches
   *     a legal domain
   */
  private boolean isLegalRequest(ServletRequest request, ServletResponse response) {
    if (!csrfEnable || isWhiteURL(request, response)) {
      return true;
    }
    String referer = ((HttpServletRequest) request).getHeader(HttpParams.REFERER);
    return referer != null && domains.stream().anyMatch(domain -> matchesDomain(referer, domain));
  }

  /**
   * Checks whether {@code referer} starts with {@code domain} at a URL boundary.
   *
   * <p>A plain {@code startsWith} would let {@code http://a.com} accept
   * {@code http://a.com.evil.org/}. The character following the prefix must therefore be the end
   * of the string or one of {@code / : ? #}, unless the domain itself already ends with {@code /}.
   *
   * <p>Empty or {@code null} entries never match. With a plain {@code startsWith}, an empty entry
   * would match every Referer.
   */
  private static boolean matchesDomain(String referer, String domain) {
    if (domain == null || domain.isEmpty() || !referer.startsWith(domain)) {
      return false;
    }
    if (referer.length() == domain.length() || domain.endsWith("/")) {
      return true;
    }
    char next = referer.charAt(domain.length());
    // ':' keeps "http://host" matching "http://host:8080/..." when no port is configured.
    return next == '/' || next == ':' || next == '?' || next == '#';
  }

  /**
   * Returns whether the request targets a URL that is exempt from the Referer check: the login
   * URL, {@code "/"} or {@code "/actuator/info"}.
   *
   * @param request the incoming request
   * @param response the outgoing response (required by Shiro's {@code isLoginRequest})
   * @return {@code true} if no Referer validation is needed
   */
  private boolean isWhiteURL(ServletRequest request, ServletResponse response) {
    return isLoginRequest(request, response) || pathsMatch(ROOT_URI, request) || pathsMatch(ACTUATOR_URI, request);
  }

}
